import argparse
import numpy as np
import torch
import gym
from vae import VAE
#from coolname import generate_slug
import utils
import d4rl
from tqdm import tqdm
import h5py
import torch.nn as nn


# Example:
#   CUDA_VISIBLE_DEVICES=1 python add_score_vae.py --seed=1 --num_iters=1 --dataset umaze
parser = argparse.ArgumentParser(
    description='Score antmaze transitions by the latent distance of their VAE '
                'posterior mean to the expert centre.')
parser.add_argument('--seed', type=int, default=0,
                    help='seed for numpy and torch RNGs')
# antmaze dataset names as registered by D4RL: umaze, umaze-diverse,
# medium-play, medium-diverse, large-play, large-diverse.
# ('medium' alone would build the unregistered env id 'antmaze-medium-v2'.)
parser.add_argument('--dataset', type=str, default='umaze',
                    help='antmaze dataset name, e.g. umaze, medium-play, large-diverse')
parser.add_argument('--num_iters', type=int, default=int(1e5),
                    help='accepted for CLI compatibility; not used by the scoring below')
args = parser.parse_args()

# Seed the RNGs so any sampling done inside the VAE forward pass (if any) is
# reproducible across runs; previously --seed was parsed but never applied.
np.random.seed(args.seed)
torch.manual_seed(args.seed)


# Prefer the GPU, but fall back to CPU instead of crashing on the first
# .to('cuda') when no GPU is visible.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the environment; in this script it is only consulted for the
# observation/action space dimensions and action bound.
env_name = 'antmaze-' + args.dataset + '-v2'
env = gym.make(env_name)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

print(state_dim, action_dim, max_action)
# Latent size of the VAE; must match the size the checkpoint was trained with.
latent_dim = action_dim * 2

# Merged dataset whose transitions are scored below.
replay_buffer = utils.ReplayBuffer(state_dim, action_dim)
replay_buffer.convert_selfdata('../implicit_q_learning/datasets/datasets_merge_full_split_1/antmaze-' + args.dataset + '-v2.hdf5')
# Fine-tuning ("expert") buffer used to locate the expert centre in latent space.
replay_selfbuffer = utils.ReplayBuffer(state_dim, action_dim)
replay_selfbuffer.convert_selfdata('../implicit_q_learning/datasets/datasets_cvae_full_split_1/antmaze-' + args.dataset + '-v2-finetune.hdf5')

states = replay_buffer.state
actions = replay_buffer.action

# Pretrained VAE. map_location loads the checkpoint onto `device` regardless
# of the device it was saved from; eval() switches layers such as dropout
# (if the model has any) to inference behaviour, since it is only used for
# scoring here.
vae = VAE(state_dim, action_dim, latent_dim, max_action, hidden_dim=512).to(device)
vae.load_state_dict(torch.load('../implicit_q_learning/models/cvae_full_split_1_varloss/vae_model_antmaze_' + args.dataset + '.pt',
                               map_location=device))
vae.eval()

total_size = states.shape[0]
# Centre of the expert distribution in latent space: the average VAE
# posterior mean (and std) over every transition of the fine-tuning buffer.
# This is inference only, so run it under no_grad; otherwise the forward pass
# over the whole buffer keeps an autograd graph and all activations alive.
# NOTE(review): assumes replay_selfbuffer.state holds exactly the loaded rows
# (no zero-padded preallocation), otherwise padding would bias the centre.
states_expert = replay_selfbuffer.state
actions_expert = replay_selfbuffer.action
with torch.no_grad():
    train_states = torch.from_numpy(states_expert).float().to(device)
    train_actions = torch.from_numpy(actions_expert).float().to(device)
    _, mean_all, std_all = vae(train_states, train_actions)
    mean = torch.mean(mean_all, 0)
    # Not used by the L2 distance score; kept for alternative (e.g. KL) scores.
    std = torch.mean(std_all, 0)


# Score every transition of the merged dataset by the L2 distance between its
# VAE posterior mean and the expert centre `mean`. Rows are encoded in
# batches under no_grad rather than with one forward pass, one autograd graph
# and one .item() host sync per transition, which made the loop very slow.
# NOTE(review): assumes the VAE treats batch rows independently (no
# batch-coupled layers such as BatchNorm in train mode), so batched means
# equal the per-row means of a one-at-a-time loop -- confirm against vae.py.
# An alternative score, exp(-KL) between the per-row and expert Gaussians
# (using `std`), was tried previously and is disabled.
SCORE_BATCH_SIZE = 4096
# Hoisted out of the loop; computes ||x1 - x2 + eps||_2 with eps=1e-6.
pdist = nn.PairwiseDistance(p=2)

scores_ = []
with torch.no_grad():
    for start in tqdm(range(0, total_size, SCORE_BATCH_SIZE)):
        stop = start + SCORE_BATCH_SIZE
        batch_states = torch.from_numpy(states[start:stop]).float().to(device)
        batch_actions = torch.from_numpy(actions[start:stop]).float().to(device)
        _, mean1, _ = vae(batch_states, batch_actions)
        # Same argument order as before (centre first) so the eps offset is
        # applied identically; the centre is broadcast to every batch row.
        output = pdist(mean.expand_as(mean1), mean1)
        scores_.extend(output.cpu().tolist())
# Write a copy of the merged dataset with one score per transition added
# under the 'scores' key.
src_path = '../implicit_q_learning/datasets/datasets_merge_full_split_1/antmaze-' + args.dataset + '-v2.hdf5'
dst_path = '../implicit_q_learning/datasets/datasets_merge_full_split_1/antmaze-' + args.dataset + '-v2-oriscores.hdf5'
scores = np.asarray(scores_, dtype=np.float64)

with h5py.File(src_path, 'r') as f1:
    # scores_ is computed row-by-row from replay_buffer; if its length does
    # not match the file being copied, the scores would be silently
    # misaligned. Check before opening (and truncating) the output file.
    n_rows = f1['observations'].shape[0]
    if scores.shape[0] != n_rows:
        raise ValueError(
            f'expected {n_rows} scores to match {src_path}, got {scores.shape[0]}')
    with h5py.File(dst_path, 'w') as f2:
        for key in ('observations', 'actions', 'next_observations', 'rewards', 'terminals'):
            # Reading with [:] already returns an in-memory ndarray copy.
            f2[key] = f1[key][:]
        f2['scores'] = scores
